{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48e216ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import torch\n",
    "import argparse\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "from utils import evaluate_relaxed_accuracy, model_gen\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "\n",
    "# Checkpoint identifier shared by the tokenizer and the model.\n",
    "ckpt_path = 'internlm/internlm-xcomposer2-vl-7b'\n",
    "\n",
    "# trust_remote_code=True is passed to both loaders for this checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    ckpt_path,\n",
    "    device_map=\"cuda\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "\n",
    "# Inference setup: eval mode, moved to the GPU, then cast to half precision.\n",
    "model = model.eval().cuda().half()\n",
    "\n",
    "# Attach the tokenizer to the model object\n",
    "# (presumably read by model_gen -- TODO confirm in utils).\n",
    "model.tokenizer = tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d23c8abd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Evaluate the model on the questions listed in test_human.json.\n",
    "# A context manager closes the annotation file as soon as it has been parsed,\n",
    "# and the explicit encoding avoids depending on the platform default.\n",
    "with open('data/chartqa/ChartQA Dataset/test/test_human.json', encoding='utf-8') as annotation_file:\n",
    "    samples = json.load(annotation_file)\n",
    "\n",
    "human_part = []\n",
    "for q in tqdm(samples):\n",
    "    # Each sample provides an image file name, a question and a reference label.\n",
    "    im_path = 'data/chartqa/ChartQA Dataset/test/png/' + q['imgname']\n",
    "    # Prompt: user turn with the short-answer instruction and the question,\n",
    "    # followed by an opened assistant turn for the model to complete.\n",
    "    text = '[UNUSED_TOKEN_146]user\\nAnswer the question using a single word or phrase.{}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n'.format(q['query'])\n",
    "    # torch.autocast(device_type='cuda') is the non-deprecated spelling of\n",
    "    # torch.cuda.amp.autocast().\n",
    "    with torch.autocast(device_type='cuda'):\n",
    "        response = model_gen(model, text, im_path)\n",
    "    human_part.append({\n",
    "        'answer': response,\n",
    "        'annotation': q['label'],\n",
    "    })\n",
    "\n",
    "# Score the collected answer/annotation pairs with the helper imported from utils.\n",
    "human_part_acc = evaluate_relaxed_accuracy(human_part)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0b9f3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the model on the questions listed in test_augmented.json.\n",
    "# A context manager closes the annotation file as soon as it has been parsed,\n",
    "# and the explicit encoding avoids depending on the platform default.\n",
    "with open('data/chartqa/ChartQA Dataset/test/test_augmented.json', encoding='utf-8') as annotation_file:\n",
    "    samples = json.load(annotation_file)\n",
    "\n",
    "augmented_part = []\n",
    "for q in tqdm(samples):\n",
    "    # Each sample provides an image file name, a question and a reference label.\n",
    "    im_path = 'data/chartqa/ChartQA Dataset/test/png/' + q['imgname']\n",
    "    # Prompt: user turn with the short-answer instruction and the question,\n",
    "    # followed by an opened assistant turn for the model to complete.\n",
    "    text = '[UNUSED_TOKEN_146]user\\nAnswer the question using a single word or phrase.{}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n'.format(q['query'])\n",
    "    # torch.autocast(device_type='cuda') is the non-deprecated spelling of\n",
    "    # torch.cuda.amp.autocast().\n",
    "    with torch.autocast(device_type='cuda'):\n",
    "        response = model_gen(model, text, im_path)\n",
    "    augmented_part.append({\n",
    "        'answer': response,\n",
    "        'annotation': q['label'],\n",
    "    })\n",
    "\n",
    "# Score the collected answer/annotation pairs with the helper imported from utils.\n",
    "augmented_part_acc = evaluate_relaxed_accuracy(augmented_part)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7256087",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall score: unweighted mean of the human and augmented split accuracies.\n",
    "print((human_part_acc + augmented_part_acc) / 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e88949be",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
